import os

import numpy as np
from rknn.api import RKNN

from geartrain.utils.import_utils import is_torch_available
from .base import Backend

if is_torch_available():
    import torch


class X86_RKNN_Backend(Backend):
    """Backend that runs ONNX models through the RKNN toolkit on x86.

    On load, the ONNX model at ``model_path`` is converted/built for the
    rk3588 target, exported next to the source file as ``<name>.rknn``,
    and executed through the RKNN runtime (simulator when run on x86).
    """

    def __init__(self, model_path=None, device="rk3588", **kwargs):
        """Initialize the backend.

        Args:
            model_path (str): Path to the source ONNX model file.
            device (str): Target device identifier passed to the base class.
            **kwargs: Must contain ``data_format`` — the input tensor layout
                string forwarded to ``RKNN.inference`` (e.g. "nchw"/"nhwc").

        Raises:
            KeyError: If ``data_format`` is not supplied in kwargs.
        """
        # Required: layout of the tensors handed to RKNN.inference().
        self.data_format = kwargs["data_format"]

        self.model_path = model_path
        # Inputs are down-cast to float16 before inference (see _forward).
        self.fp16 = True
        super().__init__(device, "x86rknn")

    def _load_model(self):
        """Convert the ONNX model to RKNN, export it, and init the runtime.

        Raises:
            RuntimeError: If loading, building, or exporting the model fails.
            SystemExit: If runtime initialization fails (original behavior
                kept: prints a message and exits with the RKNN return code).
        """
        self.rknn = RKNN(verbose=False)
        # Identity mean/std: no normalization inside RKNN, inputs are
        # passed through as-is.
        self.rknn.config(
            mean_values=[[0, 0, 0]],
            std_values=[[1, 1, 1]],
            target_platform="rk3588",
            optimization_level=1,
        )

        # RKNN APIs signal failure via a non-zero return code; check each
        # step instead of silently ignoring it (failures otherwise surface
        # later as confusing runtime errors).
        ret = self.rknn.load_onnx(model=self.model_path)
        if ret != 0:
            raise RuntimeError(f"Failed to load ONNX model: {self.model_path}")

        ret = self.rknn.build(do_quantization=False)
        if ret != 0:
            raise RuntimeError(f"RKNN build failed for: {self.model_path}")

        # os.path.splitext strips only the final extension; the previous
        # split(".")[0] truncated at the first dot anywhere in the path
        # (e.g. "./models/v1.2/net.onnx" -> "./models/v1").
        output_path = os.path.splitext(self.model_path)[0] + ".rknn"
        ret = self.rknn.export_rknn(output_path)
        if ret != 0:
            raise RuntimeError(f"Failed to export RKNN model to: {output_path}")

        ret = self.rknn.init_runtime()
        if ret != 0:
            print("Init runtime environment failed!")
            exit(ret)
        print("done")

    def _forward(self, model_inputs):
        """Run inference on a single input tensor/array.

        Args:
            model_inputs (np.ndarray | torch.Tensor): Input batch; torch
                tensors are moved to CPU and converted to numpy.

        Returns:
            Model output(s) converted via :meth:`from_numpy` — a single
            value when the model has one output, else a list.
        """
        if not isinstance(model_inputs, np.ndarray):
            # Assumed to be a torch.Tensor (possibly on GPU).
            _inputs = model_inputs.cpu().numpy()
        else:
            _inputs = model_inputs

        if self.fp16:
            _inputs = _inputs.astype(np.float16)

        y = self.rknn.inference(inputs=[_inputs], data_format=[self.data_format])

        if isinstance(y, (list, tuple)):
            return (
                self.from_numpy(y[0])
                if len(y) == 1
                else [self.from_numpy(x) for x in y]
            )
        else:
            return self.from_numpy(y)

    def from_numpy(self, x):
        """
        Convert a numpy array to a tensor.

        Args:
            x (np.ndarray): The array to be converted.

        Returns:
            (torch.Tensor): The converted tensor on ``self.device``, or ``x``
            unchanged when torch is unavailable or ``x`` is not an ndarray.
        """
        if is_torch_available():
            return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x
        else:
            return x
